import torch
from lowp.functional import truncate_bf16, truncate_fp8


def mean_error(x, y):
    """Return the mean of the squared element-wise difference of x and y.

    Despite the name, this is a mean *squared* error: the difference is
    squared before averaging. The result is converted to a Python float.
    """
    diff = x - y
    squared = diff.pow(2)
    return float(squared.mean())


# Draw a small float32 sample, move it to the GPU, and print it next to the
# results of the lowp BF16 and FP8 truncation helpers.
x = torch.randn(5).to("cuda", torch.float)
print("FP32:", x)

# NOTE(review): roundingMode=0 is passed straight through to lowp; what it
# selects is not visible from this file — confirm against the lowp docs.
bf16 = truncate_bf16(x, roundingMode=0)
print("BF16:", bf16)

t_fp8 = truncate_fp8(x, roundingMode=0)
print("FP8:", t_fp8)

# Mean squared difference between the FP32 input and its FP8 truncation.
fp8_error = mean_error(x, t_fp8)
print(fp8_error)

# Repeat the FP8 truncation with noise enabled and average the results, then
# compare that average with the original FP32 input.
itr = 1000
# Running sum is kept in float64 (presumably to limit rounding error while
# adding up many samples).
acc = torch.zeros_like(x, dtype=torch.float64)

for _ in range(itr):
    # NOTE(review): roundingMode=4 with a signed 32-bit noise range presumably
    # selects stochastic rounding (matching the label printed below) —
    # confirm against the lowp docs.
    t_fp8 = truncate_fp8(
        x, roundingMode=4, min_noise=-(2**31), max_noise=2**31 - 1
    )
    acc.add_(t_fp8.to(torch.float64))

# The error is measured on the float64 average; the float32 copy is only
# used for display.
mean_fp8 = acc / itr
t_fp8_stoch = mean_fp8.float()
print(mean_error(x, mean_fp8))
print('FP8 (avg stoch. rounding):',  t_fp8_stoch)